classdef Segmentation_mAccuracy < dagnn.Loss
  % SEGMENTATION_MACCURACY  Multi-scale segmentation accuracy monitor.
  %   Inputs, in order: ground-truth labels, then S score maps (class scores
  %   along the third dimension). The first score map is used at its own
  %   resolution; every other score map is resized with IMRESIZE to the
  %   label resolution. A fused score map is formed as
  %   sum_s scaleWeights(s) * scores_s.
  %
  %   One confusion matrix is accumulated per score map plus one for the
  %   fused map, and the output is the column vector
  %
  %     [mIoU(scale 1); ...; mIoU(scale S); mIoU(fused);
  %      meanAccuracy(scale 1); pixelAccuracy(scale 1)]
  %
  %   (6x1 for the default three scales). Statistics are cumulative since
  %   the last RESET. Pixels whose label is <= 0 are ignored. The layer has
  %   no derivatives.

  properties
    % Accumulated confusion matrices (rows: ground-truth label, columns:
    % predicted class). Slots 1..S hold the individual scales, slot S+1 the
    % fused prediction; any further slots are unused.
    confusion = {0,0,0,0,0,0}
    % Number of classes; labels and predicted class indices are expected
    % to lie in 1..numClasses (accumarray errors otherwise).
    numClasses = 21
    % Fusion weight of each score map, one entry per scale.
    scaleWeights = [0.3, 0.3, 0.3]
  end

  methods
    function outputs = forward(obj, inputs, params)
      % FORWARD  Accumulate confusion matrices and return current metrics.
      %   INPUTS{1} are the labels, INPUTS{2:end} the per-scale score maps.
      %   PARAMS is unused. OUTPUTS{1} is the metric vector described in
      %   the class help.
      C = obj.numClasses ;
      mainScale = 1 ;
      numScales = numel(inputs) - 1 ;
      if numel(obj.scaleWeights) ~= numScales
        error('Segmentation_mAccuracy:scaleWeights', ...
          'Expected %d scale weights, got %d.', numScales, numel(obj.scaleWeights)) ;
      end

      labels = gather(inputs{1}) ;
      labelSize = [size(labels,1), size(labels,2)] ;

      % Collect per-scale scores (auxiliary scales resized to the label
      % resolution) and build the weighted fusion in the same pass.
      numOutputs = numScales + 1 ;
      scores = cell(1, numOutputs) ;
      fused = 0 ;
      for s = 1:numScales
        x = gather(inputs{s+1}) ;
        if s ~= mainScale
          % imresize only resizes the first two dimensions, so all class
          % channels (and any batch dimension) are handled in one call.
          x = imresize(x, labelSize) ;
        end
        scores{s} = x ;
        fused = fused + obj.scaleWeights(s) * x ;
      end
      scores{numOutputs} = fused ;

      % Make sure there is a confusion slot for every score map + fusion.
      if numel(obj.confusion) < numOutputs
        obj.confusion(end+1:numOutputs) = {0} ;
      end

      % Only pixels with a positive label contribute to the statistics.
      ok = labels > 0 ;
      numPixels = sum(ok(:)) ;
      gt = labels(ok) ;

      mIou = zeros(numOutputs, 1) ;
      for k = 1:numOutputs
        [~, predictions] = max(scores{k}, [], 3) ;
        obj.confusion{k} = obj.confusion{k} + ...
          accumarray([gt, predictions(ok)], 1, [C C]) ;

        cm = obj.confusion{k} ;
        pos = sum(cm, 2) ;     % ground-truth pixel count per class
        res = sum(cm, 1)' ;    % predicted pixel count per class
        tp = diag(cm) ;
        % max(1, .) avoids 0/0; classes never seen contribute 0 to the mean.
        mIou(k) = mean(tp ./ max(1, pos + res - tp)) ;
        if k == mainScale
          pixelAccuracy = sum(tp) / max(1, sum(cm(:))) ;
          meanAccuracy = mean(tp ./ max(1, pos)) ;
        end
      end

      obj.average = [mIou; meanAccuracy; pixelAccuracy] ;
      obj.numAveraged = obj.numAveraged + numPixels ;
      outputs{1} = obj.average ;
    end

    function [derInputs, derParams] = backward(obj, inputs, params, derOutputs)
      % BACKWARD  Monitoring-only layer: no derivative for any input.
      %   One (empty) entry is returned per input so the number of
      %   derivatives matches the number of inputs.
      derInputs = cell(1, numel(inputs)) ;
      derParams = {} ;
    end

    function reset(obj)
      % RESET  Clear the accumulated confusion matrices and averages.
      obj.confusion = {0,0,0,0,0,0} ;
      obj.average = 0 ;
      obj.numAveraged = 0 ;
    end

    function obj = Segmentation_mAccuracy(varargin)
      % Constructor: name/value pairs (e.g. 'numClasses', 'scaleWeights')
      % are applied to the properties via obj.load.
      obj.load(varargin) ;
    end
  end
end